PyTorch 如何利用多个损失开展深度神经网络的训练过程【持续更新】 您所在的位置:网站首页 多个神经网络 联合训练 PyTorch 如何利用多个损失开展深度神经网络的训练过程【持续更新】

PyTorch 如何利用多个损失开展深度神经网络的训练过程【持续更新】

2024-06-30 20:24| 来源: 网络整理| 查看: 265

咱们直接进入正题!

def train(model, loss1, loss2, train_dataloader, optimizer_loss1, optimizer_loss2, epoch, writer, device_num): model.train() device = torch.device("cuda:"+str(device_num)) correct = 0 value_loss1 = 0 value_loss2 = 0 result_loss = 0 for data_nnl in train_dataloader: data, target = data_nnl target = target.long() if torch.cuda.is_available(): data = data.to(device) target = target.to(device) optimizer_loss1.zero_grad() optimizer_loss2.zero_grad() output = model(data) classifier_output = F.log_softmax(output[1], dim=1) value_loss1_batch = loss1(classifier_output, target) //第一个损失项 value_loss2_batch = loss2(output[0], target) //第二个损失项 weight_loss2 = 0.005 result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch result_loss_batch.backward() optimizer_loss1.step() for param in loss2.parameters(): param.grad.data *= (1. / weight_loss2) optimizer_loss2.step()

我这里采用的是两项损失,loss1用于优化网络权重,loss2用于优化中心矢量,二者均是可训练的超参,因此包含两个优化器,如果多个损失项均用于优化网络权重,那么只采用一个优化器即可,如下所示

def train(model, loss1, loss2, train_dataloader, optimizer, epoch, writer, device_num): model.train() device = torch.device("cuda:"+str(device_num)) correct = 0 value_loss1 = 0 value_loss2 = 0 result_loss = 0 for data_nnl in train_dataloader: data, target = data_nnl target = target.long() if torch.cuda.is_available(): data = data.to(device) target = target.to(device) optimizer.zero_grad() output = model(data) classifier_output = F.log_softmax(output[1], dim=1) value_loss1_batch = loss1(classifier_output, target) //第一个损失项 value_loss2_batch = loss2(output[0], target) //第二个损失项 weight_loss2 = 0.005 result_loss_batch = value_loss1_batch + weight_loss2 * value_loss2_batch result_loss_batch.backward() optimizer.step()

详细代码,请翻阅我们的论文,代码已开源,开源链接可查论文摘要。

若该经验贴对您科研、学习有所帮助,欢迎您引用我们的论文。

[1] X. Fu et al., "Semi-Supervised Specific Emitter Identification Method Using Metric-Adversarial Training," in IEEE Internet of Things Journal, vol. 10, no. 12, pp. 10778-10789, 15 June15, 2023, doi: 10.1109/JIOT.2023.3240242.

[2] X. Fu et al., "Semi-Supervised Specific Emitter Identification via Dual Consistency Regularization," in IEEE Internet of Things Journal, doi: 10.1109/JIOT.2023.3281668.



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有